import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
import matplotlib.pyplot as plt

# 1. Generate random node data.
num_nodes = 100
x = torch.rand((num_nodes, 2))  # one 2-dimensional feature vector per node
y = (x[:, 0] + x[:, 1] > 1).long()  # binary label: 1 when the two features sum to more than 1

# 2. Build the graph structure. Every sampled undirected edge is stored as
# two directed edges (i -> j and j -> i). Same-label pairs are linked when the
# random draw exceeds 0.6, different-label pairs only when it exceeds 0.9, so
# edges mostly connect nodes of the same class.
labels = y.tolist()  # plain Python ints: avoids one tensor comparison per node pair
edge_pairs = []
for i in range(num_nodes):
    for j in range(i + 1, num_nodes):
        threshold = 0.6 if labels[i] == labels[j] else 0.9
        # Exactly one random draw per pair, as before.
        if torch.rand(1).item() > threshold:
            edge_pairs.append([i, j])
            edge_pairs.append([j, i])
# The explicit (E, 2) shape keeps the tensor 2-D even if no edge was sampled;
# .contiguous() is needed because the transpose alone is only a strided view.
edge_index = (
    torch.tensor(edge_pairs, dtype=torch.long)
    .reshape(len(edge_pairs), 2)
    .t()
    .contiguous()
)

# 3. Train / test split.
train_mask = torch.rand(num_nodes) < 0.8  # each node is a training node with probability 0.8
test_mask = ~train_mask  # every remaining node is a test node

# 4. Bundle features, connectivity, labels and split masks into one PyG Data object.
data = Data(
    x=x,
    edge_index=edge_index,
    y=y,
    train_mask=train_mask,
    test_mask=test_mask,
)


# 5. Four-layer GCN model.
class GCN(torch.nn.Module):
    """Four stacked GCNConv layers with a skip connection around the third.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width shared by the three hidden layers. The skip
            connection adds the third layer's input to its output, so both
            must have this same width.
        out_channels: Number of output classes.
    """

    def __init__(
        self,
        in_channels: int = 2,
        hidden_channels: int = 16,
        out_channels: int = 2,
    ) -> None:
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        # Input and output widths match so the skip connection can add them.
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.conv4 = GCNConv(hidden_channels, out_channels)

    def forward(self, data):
        """Run the network on a graph and return log-probabilities.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``
                (graph connectivity), which are passed to every layer.

        Returns:
            The output of the last layer after ``log_softmax`` over
            dimension 1.
        """
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        # Skip connection: add the layer's input back onto its activation.
        x = F.relu(self.conv3(x, edge_index)) + x
        x = self.conv4(x, edge_index)
        return F.log_softmax(x, dim=1)


# 6. Train the model.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

num_epochs = 2000
data = data.to(device)
model = GCN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
# Halve the learning rate every 500 scheduler steps (one step per epoch).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)

for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()

    # Full-graph forward pass; only training nodes contribute to the loss.
    log_probs = model(data)
    train_nodes = data.train_mask
    loss = F.nll_loss(log_probs[train_nodes], data.y[train_nodes])

    loss.backward()
    optimizer.step()
    scheduler.step()

    # Report progress every 200 epochs.
    if not epoch % 200:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

# 7. Evaluate the model.
model.eval()
# Inference only: no_grad avoids recording an autograd graph for this pass.
with torch.no_grad():
    out = model(data)
pred = out.argmax(dim=1)  # index of the largest log-probability = predicted class
# Boolean-mask indexing returns a copy, so the edit below leaves `pred` intact.
test_pred = pred[data.test_mask]
test_true = data.y[data.test_mask]

# 8. Filter out low-confidence predictions.
proba = torch.exp(out)  # log-probabilities -> probabilities
test_confidence = proba[data.test_mask].max(dim=1).values
test_pred[test_confidence < 0.6] = -1  # -1 marks predictions below 0.6 confidence

# 9. Visualize the test results.
# Build the index tensor on the mask's device: indexing a CPU tensor with a
# mask that lives on the GPU raises a RuntimeError.
test_mask_np = (
    torch.arange(num_nodes, device=data.test_mask.device)[data.test_mask]
    .cpu()
    .numpy()
)
test_pred_np = test_pred.cpu().numpy()
test_true_np = test_true.cpu().numpy()

# Predicted and true classes are drawn against the node index; where both
# agree the two semi-transparent markers overlap.
plt.figure(figsize=(10, 5))
plt.scatter(test_mask_np, test_pred_np, color='blue', alpha=0.5, label='Predicted')
plt.scatter(test_mask_np, test_true_np, color='red', alpha=0.5, label='True')
plt.xlabel('Test Node Index')
plt.ylabel('Node Class')
plt.title('Test Results vs True Results')
plt.legend()
plt.show()
